{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Example Usages of NAS Benchmarks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pprint\n",
        "import time\n",
        "\n",
        "from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats\n",
        "from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats\n",
        "from nni.nas.benchmarks.nds import query_nds_trial_stats\n",
        "from nni.nas.benchmarks.nlp import query_nlp_trial_stats\n",
        "\n",
        "# Start a wall-clock timer; the last cell prints the time elapsed since this point.\n",
        "ti = time.time()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NAS-Bench-101"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use the following architecture as an example:\n",
        "\n",
        "![nas-101](../../img/nas-bench-101-example.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# The example architecture pictured above, expressed as `op*` / `input*` entries.\n",
        "arch = dict(\n",
        "    op1='conv3x3-bn-relu',\n",
        "    op2='maxpool3x3',\n",
        "    op3='conv3x3-bn-relu',\n",
        "    op4='conv3x3-bn-relu',\n",
        "    op5='conv1x1-bn-relu',\n",
        "    input1=[0],\n",
        "    input2=[1],\n",
        "    input3=[2],\n",
        "    input4=[0],\n",
        "    input5=[0, 3, 4],\n",
        "    input6=[2, 5],\n",
        ")\n",
        "\n",
        "# Pretty-print every trial the benchmark returns for this query.\n",
        "nb101_trials = query_nb101_trial_stats(arch, 108, include_intermediates=True)\n",
        "for trial in nb101_trials:\n",
        "    pprint.pprint(trial)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "An architecture of NAS-Bench-101 could be trained more than once. Each element of the returned generator is a dict which contains one of the training results of this trial config (architecture + hyper-parameters) including train/valid/test accuracy, training time, number of epochs, etc. The results of NAS-Bench-201 and NDS follow similar formats."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NAS-Bench-201"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use the following architecture as an example:\n",
        "\n",
        "![nas-201](../../img/nas-bench-201-example.png)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# The example architecture pictured above; `arch` is reused by the next code cell.\n",
        "arch = {\n",
        "    '0_1': 'avg_pool_3x3',\n",
        "    '0_2': 'conv_1x1', '1_2': 'skip_connect',\n",
        "    '0_3': 'conv_1x1', '1_3': 'skip_connect', '2_3': 'skip_connect',\n",
        "}\n",
        "\n",
        "cifar100_trials = query_nb201_trial_stats(arch, 200, 'cifar100')\n",
        "for trial in cifar100_trials:\n",
        "    pprint.pprint(trial)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Intermediate results are also available."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Reuses `arch` from the previous code cell. For each returned trial, print its config\n",
        "# and the number of entries under its 'intermediates' key.\n",
        "imagenet_trials = query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True)\n",
        "for trial in imagenet_trials:\n",
        "    print(trial['config'])\n",
        "    print('Intermediates:', len(trial['intermediates']))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NDS"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use the following architecture as an example:<br>\n",
        "![nds](../../img/nas-bench-nds-example.png)\n",
        "\n",
        "Here, `bot_muls`, `ds`, `num_gs`, `ss` and `ws` stand for \"bottleneck multipliers\", \"depths\", \"number of groups\", \"strides\" and \"widths\" respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "model_spec = dict(\n",
        "    bot_muls=[0.0, 0.25, 0.25, 0.25],\n",
        "    ds=[1, 16, 1, 4],\n",
        "    num_gs=[1, 2, 1, 2],\n",
        "    ss=[1, 1, 2, 2],\n",
        "    ws=[16, 64, 128, 16],\n",
        ")\n",
        "\n",
        "# Passing None for an argument acts as a wildcard.\n",
        "for trial in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):\n",
        "    pprint.pprint(trial)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Same model_spec values as in the previous code cell.\n",
        "model_spec = {\n",
        "    'bot_muls': [0.0, 0.25, 0.25, 0.25], 'ds': [1, 16, 1, 4], 'num_gs': [1, 2, 1, 2],\n",
        "    'ss': [1, 1, 2, 2], 'ws': [16, 64, 128, 16],\n",
        "}\n",
        "\n",
        "bottleneck_trials = query_nds_trial_stats(\n",
        "    'residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True\n",
        ")\n",
        "for trial in bottleneck_trials:\n",
        "    # Show at most the first ten intermediate records of each trial.\n",
        "    pprint.pprint(trial['intermediates'][:10])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "model_spec = {\n",
        "    'ds': [1, 12, 12, 12],\n",
        "    'ss': [1, 1, 2, 2],\n",
        "    'ws': [16, 24, 24, 40],\n",
        "}\n",
        "\n",
        "# A query with no None wildcards among its positional arguments.\n",
        "basic_trials = query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10')\n",
        "for trial in basic_trials:\n",
        "    pprint.pprint(trial)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Take only the first result of the 'vanilla' query instead of iterating over all of them.\n",
        "first_vanilla_trial = next(query_nds_trial_stats('vanilla', None, None, None, None, None))\n",
        "pprint.pprint(first_vanilla_trial)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Query one specific NAS cell and check that every returned trial echoes the requested specs.\n",
        "model_spec = {\n",
        "    'num_nodes_normal': 5,\n",
        "    'num_nodes_reduce': 5,\n",
        "    'depth': 12,\n",
        "    'width': 32,\n",
        "    'aux': False,\n",
        "    'drop_prob': 0.0,\n",
        "}\n",
        "cell_spec = {\n",
        "    'normal_0_op_x': 'avg_pool_3x3', 'normal_0_input_x': 0,\n",
        "    'normal_0_op_y': 'conv_7x1_1x7', 'normal_0_input_y': 1,\n",
        "    'normal_1_op_x': 'sep_conv_3x3', 'normal_1_input_x': 2,\n",
        "    'normal_1_op_y': 'sep_conv_5x5', 'normal_1_input_y': 0,\n",
        "    'normal_2_op_x': 'dil_sep_conv_3x3', 'normal_2_input_x': 2,\n",
        "    'normal_2_op_y': 'dil_sep_conv_3x3', 'normal_2_input_y': 2,\n",
        "    'normal_3_op_x': 'skip_connect', 'normal_3_input_x': 4,\n",
        "    'normal_3_op_y': 'dil_sep_conv_3x3', 'normal_3_input_y': 4,\n",
        "    'normal_4_op_x': 'conv_7x1_1x7', 'normal_4_input_x': 2,\n",
        "    'normal_4_op_y': 'sep_conv_3x3', 'normal_4_input_y': 4,\n",
        "    'normal_concat': [3, 5, 6],\n",
        "    'reduce_0_op_x': 'avg_pool_3x3', 'reduce_0_input_x': 0,\n",
        "    'reduce_0_op_y': 'dil_sep_conv_3x3', 'reduce_0_input_y': 1,\n",
        "    'reduce_1_op_x': 'sep_conv_3x3', 'reduce_1_input_x': 0,\n",
        "    'reduce_1_op_y': 'sep_conv_3x3', 'reduce_1_input_y': 0,\n",
        "    'reduce_2_op_x': 'skip_connect', 'reduce_2_input_x': 2,\n",
        "    'reduce_2_op_y': 'sep_conv_7x7', 'reduce_2_input_y': 0,\n",
        "    'reduce_3_op_x': 'conv_7x1_1x7', 'reduce_3_input_x': 4,\n",
        "    'reduce_3_op_y': 'skip_connect', 'reduce_3_input_y': 4,\n",
        "    'reduce_4_op_x': 'conv_7x1_1x7', 'reduce_4_input_x': 0,\n",
        "    'reduce_4_op_y': 'conv_7x1_1x7', 'reduce_4_input_y': 5,\n",
        "    'reduce_concat': [3, 6],\n",
        "}\n",
        "\n",
        "nas_cell_trials = query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10')\n",
        "for trial in nas_cell_trials:\n",
        "    returned_config = trial['config']\n",
        "    assert returned_config['model_spec'] == model_spec\n",
        "    assert returned_config['cell_spec'] == cell_spec\n",
        "    pprint.pprint(trial)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Count the trials matching 'amoeba', with every other argument left as a None wildcard.\n",
        "amoeba_trials = list(query_nds_trial_stats(None, 'amoeba', None, None, None, None, None))\n",
        "print('NDS (amoeba) count:', len(amoeba_trials))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NLP"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "metadata": false
        }
      },
      "source": [
        "Use the following two architectures as examples. \n",
        "The arch in the paper is called a \"recipe\" and uses nested variables; in the NNI benchmarks it is flattened (un-nested).\n",
        "An arch has multiple Node, Node_input_n and Node_op entries; refer to the documentation for more details.\n",
        "\n",
        "arch1 : <img src=\"../../img/nas-bench-nlp-example1.jpeg\" width=400 height=300 /> \n",
        "\n",
        "\n",
        "arch2 : <img src=\"../../img/nas-bench-nlp-example2.jpeg\" width=400 height=300 /> \n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{'config': {'arch': {'h_new_0_input_0': 'node_3',\n                     'h_new_0_input_1': 'node_2',\n                     'h_new_0_input_2': 'node_1',\n                     'h_new_0_op': 'blend',\n                     'node_0_input_0': 'x',\n                     'node_0_input_1': 'h_prev_0',\n                     'node_0_op': 'linear',\n                     'node_1_input_0': 'node_0',\n                     'node_1_op': 'activation_tanh',\n                     'node_2_input_0': 'h_prev_0',\n                     'node_2_input_1': 'node_1',\n                     'node_2_input_2': 'x',\n                     'node_2_op': 'linear',\n                     'node_3_input_0': 'node_2',\n                     'node_3_op': 'activation_leaky_relu'},\n            'dataset': 'ptb',\n            'id': 20003},\n 'id': 16291,\n 'test_loss': 4.680262297102549,\n 'train_loss': 4.132040537087838,\n 'training_time': 177.05208373069763,\n 'val_loss': 4.707944253177966}\n"
          ]
        }
      ],
      "source": [
        "# arch1 from the first figure above, written as a flat dict of node inputs and ops.\n",
        "arch1 = {\n",
        "    'h_new_0_input_0': 'node_3',\n",
        "    'h_new_0_input_1': 'node_2',\n",
        "    'h_new_0_input_2': 'node_1',\n",
        "    'h_new_0_op': 'blend',\n",
        "    'node_0_input_0': 'x',\n",
        "    'node_0_input_1': 'h_prev_0',\n",
        "    'node_0_op': 'linear',\n",
        "    'node_1_input_0': 'node_0',\n",
        "    'node_1_op': 'activation_tanh',\n",
        "    'node_2_input_0': 'h_prev_0',\n",
        "    'node_2_input_1': 'node_1',\n",
        "    'node_2_input_2': 'x',\n",
        "    'node_2_op': 'linear',\n",
        "    'node_3_input_0': 'node_2',\n",
        "    'node_3_op': 'activation_leaky_relu',\n",
        "}\n",
        "# Pretty-print every trial recorded for arch1 on the 'ptb' dataset.\n",
        "for t in query_nlp_trial_stats(arch=arch1, dataset=\"ptb\"):\n",
        "    pprint.pprint(t)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[{'current_epoch': 46,\n  'id': 1796,\n  'test_loss': 6.233430054978619,\n  'train_loss': 6.4866799231542664,\n  'training_time': 146.5680329799652,\n  'val_loss': 6.326836978687959},\n {'current_epoch': 47,\n  'id': 1797,\n  'test_loss': 6.2402057403023825,\n  'train_loss': 6.485401405247535,\n  'training_time': 146.05511450767517,\n  'val_loss': 6.3239741605870865},\n {'current_epoch': 48,\n  'id': 1798,\n  'test_loss': 6.351145308363877,\n  'train_loss': 6.611281181173992,\n  'training_time': 145.8849437236786,\n  'val_loss': 6.436160816865809},\n {'current_epoch': 49,\n  'id': 1799,\n  'test_loss': 6.227155079159031,\n  'train_loss': 6.473414458249545,\n  'training_time': 145.51414465904236,\n  'val_loss': 6.313294354607077}]\n"
          ]
        }
      ],
      "source": [
        "# arch2 from the second figure above.\n",
        "arch2 = {\n",
        "    'h_new_0_input_0': 'node_0',\n",
        "    'h_new_0_input_1': 'node_1',\n",
        "    'h_new_0_op': 'elementwise_sum',\n",
        "    'node_0_input_0': 'x',\n",
        "    'node_0_input_1': 'h_prev_0',\n",
        "    'node_0_op': 'linear',\n",
        "    'node_1_input_0': 'node_0',\n",
        "    'node_1_op': 'activation_tanh',\n",
        "}\n",
        "wikitext_trials = query_nlp_trial_stats(arch=arch2, dataset='wikitext-2', include_intermediates=True)\n",
        "for trial in wikitext_trials:\n",
        "    # Print only the [45:49] slice of each trial's intermediate records.\n",
        "    pprint.pprint(trial['intermediates'][45:49])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "pycharm": {},
        "tags": []
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Elapsed time:  5.60982608795166 seconds\n"
          ]
        }
      ],
      "source": [
        "# Report the wall-clock time since `ti` was recorded in the first code cell.\n",
        "elapsed_seconds = time.time() - ti\n",
        "print('Elapsed time: ', elapsed_seconds, 'seconds')"
      ]
    }
  ],
  "metadata": {
    "file_extension": ".py",
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "name": "python",
      "version": "3.8.5-final"
    },
    "mimetype": "text/x-python",
    "name": "python",
    "npconvert_exporter": "python",
    "orig_nbformat": 2,
    "pygments_lexer": "ipython3",
    "version": 3
  },
  "nbformat": 4,
  "nbformat_minor": 2
}